# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import torch.distributed as dist
import torch
import torch.nn as nn
import numpy as np
import math
from typing import Tuple
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
from torch.nn.functional import scaled_dot_product_attention
from torch_scatter import scatter_mean
from torch.distributed import nn as dist_nn
import torch.nn.functional as F
from torch.jit import Final
from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
    get_act_layer, get_norm_layer, LayerType
from torch.distributed import nn as dist_nn
import torch.nn.init as init
import torch_scatter


def modulate(x, shift, scale):
    """
    Apply adaLN-style modulation ``x * (1 + scale) + shift``.

    When ``shift`` is 2-D (assumed (B, C)), both ``shift`` and ``scale`` are
    unsqueezed at dim 1 so they broadcast over the token axis of ``x``.
    Tensors of any other rank are used as given.
    """
    if shift.dim() != 2:
        return x * (1 + scale) + shift
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

def concat_all_gather(tensor):
    """
    All-gather ``tensor`` from every rank and concatenate the results along dim 0.

    Uses ``torch.distributed.nn.functional.all_gather``, the autograd-aware
    collective, so gradients flow back to each rank's input (the plain
    ``torch.distributed.all_gather`` does not propagate gradients). An
    initialised process group is required.
    """
    per_rank = dist_nn.functional.all_gather(tensor)
    return torch.cat(per_rank, dim=0)

# def concat_all_gather(tensor):
#     """
#     Performs all_gather operation on the provided tensors.
#     *** Warning ***: torch.distributed.all_gather has no gradient.
#     """
#     tensors_gather = [
#         torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
#     ]
#     torch.distributed.all_gather(tensors_gather, tensor.contiguous(), async_op=False)

#     output = torch.cat(tensors_gather, dim=0)
#     return output
#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################

class TimestepEmbedder(nn.Module):
    """
    Map scalar (possibly fractional) timesteps to ``hidden_size``-dim vectors.

    A fixed sinusoidal encoding of width ``frequency_embedding_size`` is passed
    through a Linear -> SiLU -> Linear MLP.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        to_hidden = nn.Linear(frequency_embedding_size, hidden_size, bias=True)
        activation = nn.SiLU()
        hidden_to_hidden = nn.Linear(hidden_size, hidden_size, bias=True)
        self.mlp = nn.Sequential(to_hidden, activation, hidden_to_hidden)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal timestep features (GLIDE formulation).

        :param t: 1-D tensor with one (possibly fractional) timestep per sample.
        :param dim: width of the returned embedding.
        :param max_period: sets the lowest frequency used.
        :return: (N, dim) tensor laid out as [cos(...), sin(...)], zero-padded
                 by one column when ``dim`` is odd.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half_dim = dim // 2
        steps = torch.arange(start=0, end=half_dim, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * steps / half_dim).to(device=t.device)
        phases = t.float().unsqueeze(-1) * freqs.unsqueeze(0)
        embedding = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        if dim % 2 == 1:
            pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t):
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features)


class LabelEmbedder(nn.Module):
    """
    Look up a learned embedding per class label.

    When ``dropout_prob > 0`` the table gets one extra row (index
    ``num_classes``) that stands for the "no label" token used by
    classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        extra_rows = 1 if dropout_prob > 0 else 0
        self.embedding_table = nn.Embedding(num_classes + extra_rows, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Replace some labels with the null-class index ``num_classes``.

        Without ``force_drop_ids`` each label is dropped independently with
        probability ``dropout_prob``; otherwise labels whose
        ``force_drop_ids`` entry equals 1 are dropped.
        """
        if force_drop_ids is not None:
            drop_mask = force_drop_ids == 1
        else:
            drop_mask = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        return torch.where(drop_mask, self.num_classes, labels)

    def forward(self, labels, train, force_drop_ids=None):
        needs_drop = (train and self.dropout_prob > 0) or force_drop_ids is not None
        if needs_drop:
            labels = self.token_drop(labels, force_drop_ids)
        return self.embedding_table(labels)


#################################################################################
#                                 Core SiT Model                                #
#################################################################################

class RMSNorm(nn.Module):
    """
    Root-mean-square layer norm (the LLaMA / T5 variant): no mean centring and
    no bias, only a learned per-channel gain.
    """
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalise in float32 for stability, then cast back before applying the gain.
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.float()
        mean_square = as_fp32.pow(2).mean(dim=-1, keepdim=True)
        normalised = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)

class FeedForward(nn.Module):
    """
    SwiGLU feed-forward: ``w2(silu(w1(x)) * w3(x))``.

    The requested ``hidden_dim`` is scaled by 2/3 (truncated to int) so the
    three bias-free projections have roughly the parameter count of a
    two-layer MLP of width ``hidden_dim``.
    """
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
    ):
        super().__init__()
        inner_dim = int(2 * hidden_dim / 3)
        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w3 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x):
        gate = torch.nn.functional.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)

def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float = 10000.0, scale=16.0):
    """
    Build a complex 2-D rotary table for a ``height x width`` grid.

    Grid coordinates are spread evenly over ``[0, scale]`` on both axes.
    ``dim // 4`` inverse frequencies are used; for each one the x-rotation and
    y-rotation are interleaved, giving a (height * width, dim // 2) complex
    tensor in row-major (y, x) order.
    """
    grid_y, grid_x = torch.meshgrid(
        torch.linspace(0, scale, height),
        torch.linspace(0, scale, width),
        indexing="ij",
    )
    grid_x = grid_x.reshape(-1)
    grid_y = grid_y.reshape(-1)

    exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)  # (dim // 4,)

    angle_x = torch.outer(grid_x, inv_freq).float()  # (H*W, dim // 4)
    angle_y = torch.outer(grid_y, inv_freq).float()
    rot_x = torch.polar(torch.ones_like(angle_x), angle_x)
    rot_y = torch.polar(torch.ones_like(angle_y), angle_y)

    # Interleave (x, y) per frequency: [x0, y0, x1, y1, ...].
    interleaved = torch.stack((rot_x, rot_y), dim=-1)
    return interleaved.reshape(height * width, -1)


def apply_rotary_emb(
        xq: torch.Tensor,
        xk: torch.Tensor,
        freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate adjacent channel pairs of ``xq`` and ``xk`` by the complex table ``freqs_cis``.

    ``xq``/``xk`` are treated as (B, N, H, Hc) and viewed as Hc/2 complex numbers
    per head. ``freqs_cis`` is broadcast as ``[None, :, None, :]``, i.e. it is
    expected to be (N, Hc/2). Outputs keep each input's dtype.
    """
    rotation = freqs_cis.unsqueeze(0).unsqueeze(2)  # (1, N, 1, Hc/2)

    def _rotate(t):
        # Pair channels (2i, 2i+1) into complex values, rotate, and flatten back.
        pairs = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
        return torch.view_as_real(pairs * rotation).flatten(3).type_as(t)

    return _rotate(xq), _rotate(xk)

class RAttention(nn.Module):
    """
    Multi-head self-attention with optional per-head QK normalisation and
    rotary position embeddings applied to queries and keys.
    """
    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = RMSNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, pos, mask=None) -> torch.Tensor:
        """
        Args:
            x: input tokens, unpacked as (B, N, C) with C == dim.
            pos: complex rotary table handed to ``apply_rotary_emb``; it is
                broadcast as ``pos[None, :, None, :]`` against (B, N, H, head_dim // 2),
                so presumably (N, head_dim // 2) — confirm with the caller.
            mask: optional ``attn_mask`` forwarded to ``scaled_dot_product_attention``.
        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 1, 3, 4)
        q, k, v = qkv.unbind(0)  # each (B, N, H, head_dim)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = apply_rotary_emb(q, k, freqs_cis=pos)
        q = q.transpose(1, 2)  # (B, H, N, head_dim)
        k = k.transpose(1, 2).contiguous()
        v = v.transpose(1, 2).contiguous()

        # Honour the configured attention dropout (it used to be hard-coded to 0.0,
        # silently ignoring ``attn_drop``); disable it outside training.
        dropout_p = self.attn_drop.p if self.training else 0.0
        x = scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_p)

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class DDTBlock(nn.Module):
    """
    Pre-norm transformer block: RMSNorm + rotary self-attention (``RAttention``)
    and RMSNorm + SwiGLU feed-forward. The normalised inputs are shifted/scaled,
    and both residual branches are gated by a linear projection of ``c``.
    """
    def __init__(self, hidden_size, groups, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = RMSNorm(hidden_size, eps=1e-6)
        self.attn = RAttention(hidden_size, num_heads=groups, qkv_bias=False)
        self.norm2 = RMSNorm(hidden_size, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = FeedForward(hidden_size, mlp_hidden_dim)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c, pos, mask=None):
        """
        Args:
            x: token sequence; presumably (B, N, C).
            c: conditioning; its last dim (C) is projected to six C-wide chunks.
                Either 2-D (B, C) or already carrying a token axis (B, 1|N, C).
            pos: rotary table forwarded to ``RAttention``.
            mask: optional attention mask forwarded to ``RAttention``.
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=-1)
        # ``modulate`` already broadcasts a 2-D (B, C) shift/scale over the token
        # axis; do the same for the gates, otherwise (B, C) * (B, N, C) mis-broadcasts.
        if gate_msa.dim() == 2:
            gate_msa = gate_msa.unsqueeze(1)
            gate_mlp = gate_mlp.unsqueeze(1)
        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos, mask=mask)
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x

class SiTBlock(nn.Module):
    """
    SiT transformer block with adaLN-Zero style conditioning.

    ``c`` goes through SiLU -> Linear to produce shift/scale/gate triples for
    the attention and MLP sub-layers (LayerNorms have no affine parameters).
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        def tanh_gelu():
            return nn.GELU(approximate="tanh")

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=tanh_gelu,
            drop=0,
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c, fl=False):
        """
        Args:
            x: token sequence.
            c: conditioning tensor; its last dim is split into six chunks.
            fl: if truthy the gates are applied as-is (presumably ``c`` already
                has a token axis); otherwise they are unsqueezed at dim 1 so a
                per-sample gate broadcasts over tokens.
        """
        chunks = self.adaLN_modulation(c).chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunks
        if not fl:
            gate_msa = gate_msa.unsqueeze(1)
            gate_mlp = gate_mlp.unsqueeze(1)

        attn_out = self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_msa * attn_out
        mlp_out = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        x = x + gate_mlp * mlp_out
        return x

class CrossAttention(nn.Module):
    """
    Multi-head attention with separate query, key and value inputs.

    Queries are projected from ``q``; keys and values from ``k`` and ``v``.
    Uses PyTorch's fused SDPA when ``use_fused_attn()`` allows it, otherwise an
    explicit softmax(QK^T * scale) V.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        # Independent projections for the three inputs.
        self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
        self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)
        # Optional per-head normalisation of queries and keys.
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # All three inputs must be 3-D; batch and channel sizes are taken from ``v``.
        (_, len_q, _), (_, len_k, _), (batch, len_v, channels) = q.shape, k.shape, v.shape
        heads, width = self.num_heads, self.head_dim

        query = self.q_proj(q).reshape(batch, len_q, heads, width).transpose(1, 2)
        key = self.k_proj(k).reshape(batch, len_k, heads, width).transpose(1, 2)
        value = self.v_proj(v).reshape(batch, len_v, heads, width).transpose(1, 2)
        query = self.q_norm(query)
        key = self.k_norm(key)

        if not self.fused_attn:
            scores = (query * self.scale) @ key.transpose(-2, -1)
            probs = self.attn_drop(scores.softmax(dim=-1))
            out = probs @ value
        else:
            out = F.scaled_dot_product_attention(
                query, key, value,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )

        out = out.transpose(1, 2).reshape(batch, len_q, channels)
        return self.proj_drop(self.proj(out))

class CrossAttentionTransformer(nn.Module):
    """
    Post-norm cross-attention block.

    ``query`` attends to ``key_value`` via ``CrossAttention``; a GELU
    feed-forward network follows. Each sub-layer is a residual connection
    followed by LayerNorm.
    """
    def __init__(self, dim=384, num_heads=8, ffn_expansion=4):
        super().__init__()
        self.attn = CrossAttention(dim, num_heads)

        # Post-residual layer norms.
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # Position-wise feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * ffn_expansion),
            nn.GELU(),
            nn.Linear(dim * ffn_expansion, dim)
        )

        self._init_weights()

    def _init_weights(self):
        """
        Reset the LayerNorms to identity and initialise the FFN projections.

        The expansion layer gets Kaiming-normal weights and the output layer a
        down-scaled (gain 1/sqrt(2)) Xavier-uniform init; biases start at zero.
        Layers are chosen by position, not by ``in_features``. The old
        ``in_features == dim`` test also matched the output layer when
        ``ffn_expansion == 1``.
        """
        for norm in (self.norm1, self.norm2):
            nn.init.constant_(norm.weight, 1.0)
            nn.init.constant_(norm.bias, 0.0)

        expand, project = self.ffn[0], self.ffn[-1]
        nn.init.kaiming_normal_(expand.weight, nonlinearity='relu')
        nn.init.xavier_uniform_(project.weight, gain=1 / math.sqrt(2))
        for layer in (expand, project):
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, query, key_value):
        """Cross-attend ``query`` to ``key_value`` (used as both keys and values), then apply the FFN."""
        attn_out = self.attn(query, key_value, key_value)
        x = self.norm1(query + attn_out)  # residual + LayerNorm

        ffn_out = self.ffn(x)
        out = self.norm2(x + ffn_out)  # residual + LayerNorm
        return out

class SparseCrossAttention(nn.Module):
    """
    Top-k sparse multi-head cross-attention against a shared key/value bank.

    For every query and head, only the ``top_k`` highest-scoring bank entries
    take part in the softmax. The bank (``key_value``) has batch size 1 and is
    shared by all queries in the batch.
    """
    def __init__(self, dim=384, num_heads=8, top_k=128):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.top_k = top_k

        assert self.head_dim * num_heads == dim, "dim must be divisible by num_heads"

        # Query projection.
        self.q_proj = nn.Linear(dim, dim, bias=False)
        # Key / value projections of the shared bank.
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)

        # Output projection.
        self.out_proj = nn.Linear(dim, dim)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform projections; the output projection uses gain 1/sqrt(2) and a zero bias."""
        nn.init.xavier_uniform_(self.q_proj.weight)
        nn.init.xavier_uniform_(self.k_proj.weight)
        nn.init.xavier_uniform_(self.v_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight, gain=1/math.sqrt(2))
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0)

    def gather_topk_values(self, v, topk_indices):
        """
        Gather per-head value vectors at the given bank indices.

        Args:
            v: Tensor of shape [1, heads, dict_size, head_dim]. Dim 0 must be 1; it is squeezed away.
            topk_indices: Tensor of shape [batch, heads, seq_len, top_k] with entries in [0, dict_size).

        Returns:
            Tensor of shape [batch, heads, seq_len, top_k, head_dim].
        """
        batch, heads, seq_len, top_k = topk_indices.shape
        dict_size = v.size(2)
        head_dim = v.size(3)

        # [1, heads, dict_size, head_dim] -> [heads * dict_size, head_dim]; entry j of
        # head h lives at row h * dict_size + j, so one flat index gather suffices.
        v_flat = v.squeeze(0).reshape(-1, head_dim)
        offsets = torch.arange(heads, device=topk_indices.device) * dict_size
        flat_indices = (topk_indices + offsets.view(1, heads, 1, 1)).contiguous().view(-1)

        gathered = v_flat[flat_indices]  # [batch * heads * seq_len * top_k, head_dim]
        return gathered.view(batch, heads, seq_len, top_k, head_dim)

    def forward(self, query, key_value):
        """
        Args:
            query: [batch, q_len, dim]
            key_value: [1, kv_len, dim], shared by every batch element.
        Returns:
            output: [batch, q_len, dim]
        """
        batch_size = query.size(0)
        q_len = query.size(1)
        kv_len = key_value.size(1)

        # 1. Project and split heads.
        q = self.q_proj(query).view(batch_size, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.k_proj(key_value).view(1, kv_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.v_proj(key_value).view(1, kv_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # 2. Dense scores against the whole bank: [batch, heads, q_len, kv_len].
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # 3. Keep the top-k entries per query. torch.topk raises when k exceeds the
        #    searched dimension, so a bank smaller than top_k attends to every entry.
        k_eff = min(self.top_k, kv_len)
        topk_scores, topk_indices = torch.topk(attn_scores, k=k_eff, dim=-1, sorted=False)

        # 4. Softmax over the kept scores and weight the matching values.
        topk_v = self.gather_topk_values(v, topk_indices)  # [batch, heads, q_len, k, head_dim]
        attn_weights = F.softmax(topk_scores, dim=-1).unsqueeze(-1)
        output = (attn_weights * topk_v).sum(dim=3)  # [batch, heads, q_len, head_dim]

        # 5. Merge heads and project.
        output = output.permute(0, 2, 1, 3).contiguous().view(batch_size, q_len, self.dim)
        return self.out_proj(output)

class SparseCrossAttentionTransformer(nn.Module):
    """
    Post-norm block: top-k ``SparseCrossAttention`` from ``query`` into a shared
    ``key_value`` bank, then a GELU feed-forward network. Each sub-layer is a
    residual connection followed by LayerNorm.
    """
    def __init__(self, dim=384, num_heads=8, top_k=128, ffn_expansion=4):
        super().__init__()
        self.attn = SparseCrossAttention(dim, num_heads, top_k)

        # Post-residual layer norms.
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # Position-wise feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * ffn_expansion),
            nn.GELU(),
            nn.Linear(dim * ffn_expansion, dim)
        )

        self._init_weights()

    def _init_weights(self):
        """
        Reset the LayerNorms to identity and initialise the FFN projections.

        The expansion layer gets Kaiming-normal weights and the output layer a
        down-scaled (gain 1/sqrt(2)) Xavier-uniform init; biases start at zero.
        Layers are chosen by position, not by ``in_features``. The old
        ``in_features == dim`` test also matched the output layer when
        ``ffn_expansion == 1``.
        """
        for norm in (self.norm1, self.norm2):
            nn.init.constant_(norm.weight, 1.0)
            nn.init.constant_(norm.bias, 0.0)

        expand, project = self.ffn[0], self.ffn[-1]
        nn.init.kaiming_normal_(expand.weight, nonlinearity='relu')
        nn.init.xavier_uniform_(project.weight, gain=1 / math.sqrt(2))
        for layer in (expand, project):
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, query, key_value):
        """Sparse cross-attention from ``query`` into ``key_value``, then the FFN."""
        attn_out = self.attn(query, key_value)
        x = self.norm1(query + attn_out)  # residual + LayerNorm

        ffn_out = self.ffn(x)
        out = self.norm2(x + ffn_out)  # residual + LayerNorm
        return out

class learn(nn.Module):
    """
    Top-k sparse cross-attention over a shared dictionary that is refreshed on
    every call.

    Keys are projected from the incoming dictionary. The dictionary is then
    EMA-updated with ``teacher`` features routed by each query's best match
    (see ``update``), and values are projected from that refreshed dictionary.
    """
    def __init__(self, dim=384, num_heads=8, top_k=128):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.top_k = top_k

        assert self.head_dim * num_heads == dim, "dim must be divisible by num_heads"

        # Query projection.
        self.q_proj = nn.Linear(dim, dim, bias=False)
        # Key / value projections of the dictionary.
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)

        # Output projection.
        self.out_proj = nn.Linear(dim, dim)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform projections; the output projection uses gain 1/sqrt(2) and a zero bias."""
        nn.init.xavier_uniform_(self.q_proj.weight)
        nn.init.xavier_uniform_(self.k_proj.weight)
        nn.init.xavier_uniform_(self.v_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight, gain=1/math.sqrt(2))
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0)

    def gather_topk_values(self, v, topk_indices):
        """
        Gather per-head value vectors at the given dictionary indices.

        Args:
            v: Tensor of shape [1, heads, dict_size, head_dim]. Dim 0 must be 1; it is squeezed away.
            topk_indices: Tensor of shape [batch, heads, seq_len, top_k] with entries in [0, dict_size).

        Returns:
            Tensor of shape [batch, heads, seq_len, top_k, head_dim].
        """
        batch, heads, seq_len, top_k = topk_indices.shape
        dict_size = v.size(2)
        head_dim = v.size(3)

        # [1, heads, dict_size, head_dim] -> [heads * dict_size, head_dim]; entry j of
        # head h lives at row h * dict_size + j, so one flat index gather suffices.
        v_flat = v.squeeze(0).reshape(-1, head_dim)
        offsets = torch.arange(heads, device=topk_indices.device) * dict_size
        flat_indices = (topk_indices + offsets.view(1, heads, 1, 1)).contiguous().view(-1)

        gathered = v_flat[flat_indices]  # [batch * heads * seq_len * top_k, head_dim]
        return gathered.view(batch, heads, seq_len, top_k, head_dim)

    def update(self, v, topk_indices, teacher):
        """
        EMA-refresh the dictionary with teacher features routed by best-match index.

        For each head h, dictionary row j receives the mean of the teacher rows whose
        index ``topk_indices[:, h, :, 0]`` equals j (0 when no query picked j). These
        per-head means are averaged over heads and blended as
        ``0.99 * v + 0.01 * mean``.

        Args:
            v: dictionary, [1, n_dict, dim].
            topk_indices: [batch, heads, seq, k] dictionary indices. Only slot 0 of the
                last dim is read, so callers must place the best match there.
            teacher: features flattened to [batch * seq, dim]; row order must match
                ``topk_indices[:, h, :, 0]`` flattened (presumably [batch, seq, dim]).
        Returns:
            Updated dictionary, [1, n_dict, dim].

        NOTE(review): rows that no query selects still decay by a factor 0.99 per call,
        and rows hit by only some heads are averaged with zeros from the others — confirm
        this is intended.
        """
        dictionary = v.squeeze(0)            # [n_dict, dim]
        n_dict = dictionary.size(0)
        best = topk_indices[:, :, :, 0]      # [batch, heads, seq]
        num_heads = best.shape[1]

        # Loop-invariant work hoisted out of the per-head loop.
        feats = teacher.reshape(-1, teacher.size(-1)).float()  # [batch * seq, dim]
        ones = torch.ones(best.shape[0] * best.shape[2], device=best.device, dtype=torch.float32)

        res = torch.zeros_like(dictionary)
        for h in range(num_heads):
            idx = best[:, h, :].reshape(-1)  # [batch * seq]

            # Per-row feature sums and hit counts (float32 for accumulation accuracy).
            sums = torch.zeros((n_dict, feats.size(1)), device=idx.device, dtype=torch.float32)
            counts = torch.zeros(n_dict, device=idx.device, dtype=torch.float32)
            sums.index_add_(0, idx, feats)
            counts.index_add_(0, idx, ones)

            head_mean = torch.zeros_like(dictionary)
            hit = (counts > 0).nonzero(as_tuple=False).squeeze(1)
            if hit.numel() > 0:
                avg = sums[hit] / counts[hit].unsqueeze(1)
                head_mean.index_add_(0, hit, avg.to(head_mean.dtype))
            res = res + head_mean / num_heads
        return 0.99 * v + 0.01 * res.unsqueeze(0)

    def forward(self, query, teacher, key_value):
        """
        Args:
            query: [batch, q_len, dim]
            teacher: features routed into the dictionary by ``update``; presumably [batch, q_len, dim].
            key_value: dictionary [1, kv_len, dim], shared by every batch element.
        Returns:
            (output [batch, q_len, dim], detached updated dictionary [1, kv_len, dim])
        """
        batch_size = query.size(0)
        q_len = query.size(1)
        kv_len = key_value.size(1)

        # 1. Project queries and keys (keys come from the *old* dictionary) and split heads.
        q = self.q_proj(query).view(batch_size, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.k_proj(key_value).view(1, kv_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # 2. Dense scores against the whole dictionary: [batch, heads, q_len, kv_len].
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # 3. Keep the top-k entries per query; clamp k because torch.topk raises
        #    when k exceeds the dictionary size.
        k_eff = min(self.top_k, kv_len)
        topk_scores, topk_indices = torch.topk(attn_scores, k=k_eff, dim=-1, sorted=False)

        # 4. Refresh the dictionary. With sorted=False the order of the top-k results
        #    is unspecified, so topk_indices[..., 0] is not necessarily the best match;
        #    ``update`` reads exactly that slot, so hand it the arg-max explicitly.
        best_pos = topk_scores.argmax(dim=-1, keepdim=True)
        best_indices = topk_indices.gather(-1, best_pos)  # [batch, heads, q_len, 1]
        new_key_value = self.update(key_value, best_indices, teacher)

        # 5. Values come from the refreshed dictionary.
        v = self.v_proj(new_key_value).view(1, kv_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        topk_v = self.gather_topk_values(v, topk_indices)  # [batch, heads, q_len, k, head_dim]

        # 6. Softmax over the kept scores and weight the matching values.
        attn_weights = F.softmax(topk_scores, dim=-1).unsqueeze(-1)
        output = (attn_weights * topk_v).sum(dim=3)  # [batch, heads, q_len, head_dim]

        # 7. Merge heads and project.
        output = output.permute(0, 2, 1, 3).contiguous().view(batch_size, q_len, self.dim)
        output = self.out_proj(output)

        return output, new_key_value.detach()

class Brain(nn.Module):
    """
    Post-norm block pairing ``learn`` (sparse attention over a self-updating
    dictionary) with a GELU feed-forward network.

    The dictionary is stored in ``self.knowledge`` (an ``nn.Embedding`` whose
    weight has ``requires_grad=False``). Every forward call replaces it with the
    EMA-updated version returned by ``learn``.
    """
    def __init__(self, dim=384, dict_size=5000, num_heads=8, top_k=128, ffn_expansion=4):
        super().__init__()
        self.attn = learn(dim, num_heads, top_k)
        self.knowledge = nn.Embedding(dict_size, dim)
        self.knowledge.weight.requires_grad = False

        # Post-residual layer norms.
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # Position-wise feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * ffn_expansion),
            nn.GELU(),
            nn.Linear(dim * ffn_expansion, dim)
        )

        self._init_weights()

    def _init_weights(self):
        """
        Reset the LayerNorms to identity and initialise the FFN projections.

        The expansion layer gets Kaiming-normal weights and the output layer a
        down-scaled (gain 1/sqrt(2)) Xavier-uniform init; biases start at zero.
        Layers are chosen by position, not by ``in_features``. The old
        ``in_features == dim`` test also matched the output layer when
        ``ffn_expansion == 1``.
        """
        for norm in (self.norm1, self.norm2):
            nn.init.constant_(norm.weight, 1.0)
            nn.init.constant_(norm.bias, 0.0)

        expand, project = self.ffn[0], self.ffn[-1]
        nn.init.kaiming_normal_(expand.weight, nonlinearity='relu')
        nn.init.xavier_uniform_(project.weight, gain=1 / math.sqrt(2))
        for layer in (expand, project):
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, query, teacher):
        """
        Attend ``query`` into the dictionary (updating it with ``teacher``), then apply the FFN.

        NOTE(review): the dictionary is overwritten on every call, with no
        ``self.training`` check, so evaluation passes also mutate it — confirm
        that is intended.
        """
        attn_out, new_knowledge = self.attn(query, teacher, self.knowledge.weight.data.unsqueeze(0))
        # Rebind ``.data`` to the new tensor instead of copying into the old storage:
        # the old storage was fed to ``k_proj`` above, and an in-place overwrite could
        # corrupt tensors autograd saved for the backward pass.
        self.knowledge.weight.data = new_knowledge.squeeze(0).detach()
        x = self.norm1(query + attn_out)  # residual + LayerNorm

        ffn_out = self.ffn(x)
        out = self.norm2(x + ffn_out)  # residual + LayerNorm
        return out

class OminiSiTBlock(nn.Module):
    """
    SiT-style transformer block whose attention sub-layer is cross-attention.

    Conditioning follows adaLN-Zero: ``c`` is mapped to six vectors
    (shift / scale / gate for the attention branch and for the MLP branch).
    The attention query is the modulated main stream ``x``; keys and values
    are both taken from ``aux_feat`` as-is (no modulation).
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()

        def approx_gelu():
            return nn.GELU(approximate="tanh")

        hidden_mlp = int(hidden_size * mlp_ratio)
        # Submodules are created in the original order so that parameter
        # registration order (and RNG use during their default init) is unchanged.
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = CrossAttention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(in_features=hidden_size, hidden_features=hidden_mlp, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c, aux_feat):
        """
        x: main tokens (queries).
        c: conditioning vector; chunked into 6 along dim 1, so it is presumably
           (N, hidden_size) — confirm against callers.
        aux_feat: auxiliary tokens used as both keys and values.
        """
        (shift_attn, scale_attn, gate_attn,
         shift_ff, scale_ff, gate_ff) = self.adaLN_modulation(c).chunk(6, dim=1)

        attn_in = modulate(self.norm1(x), shift_attn, scale_attn)
        attn_out = self.attn(attn_in, aux_feat, aux_feat)
        x = x + gate_attn.unsqueeze(1) * attn_out

        ff_in = modulate(self.norm2(x), shift_ff, scale_ff)
        x = x + gate_ff.unsqueeze(1) * self.mlp(ff_in)
        return x

class AutoEncoder(nn.Module):
    """Token autoencoder that compresses both the channel axis and the sequence axis.

    Encoder: a per-token Linear(in_dim -> latent_dim), then a Conv1d with
    kernel_size == stride == seq_stride that merges every ``seq_stride``
    consecutive tokens into one. The compressed tokens are layer-normalized
    (no affine parameters).

    Decoder (``decode``): a ConvTranspose1d that expands each compressed token
    back into ``seq_stride`` tokens, then Linear(latent_dim -> in_dim).

    With the defaults, [N, 256, 384] is compressed to [N, 16, 24].
    """

    def __init__(self, in_dim=384, latent_dim=24, seq_stride=16):
        super(AutoEncoder, self).__init__()

        # ----- Encoder -----
        # Channel reduction: in_dim -> latent_dim.
        self.encoder_linear = nn.Linear(in_dim, latent_dim)
        # Sequence compression: L -> L // seq_stride (non-overlapping windows).
        self.encoder_conv = nn.Conv1d(
            in_channels=latent_dim,
            out_channels=latent_dim,
            kernel_size=seq_stride,
            stride=seq_stride,
            padding=0,
        )

        # ----- Decoder -----
        # Sequence expansion: L' -> L' * seq_stride.
        self.decoder_deconv = nn.ConvTranspose1d(
            in_channels=latent_dim,
            out_channels=latent_dim,
            kernel_size=seq_stride,
            stride=seq_stride,
            padding=0,
        )
        # Channel expansion: latent_dim -> in_dim.
        self.decoder_linear = nn.Linear(latent_dim, in_dim)

        self.norm = nn.LayerNorm(latent_dim, elementwise_affine=False, eps=1e-6)

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier init for the linears, Kaiming (fan_out) for the convs, zero biases."""
        init.xavier_uniform_(self.encoder_linear.weight)
        init.xavier_uniform_(self.decoder_linear.weight)

        init.kaiming_normal_(self.encoder_conv.weight, mode='fan_out', nonlinearity='relu')
        init.kaiming_normal_(self.decoder_deconv.weight, mode='fan_out', nonlinearity='relu')

        init.zeros_(self.encoder_linear.bias)
        init.zeros_(self.decoder_linear.bias)
        init.zeros_(self.encoder_conv.bias)
        init.zeros_(self.decoder_deconv.bias)

    def forward(self, x):
        """Encode a token sequence.

        Args:
            x: [N, L, in_dim]. L should be a multiple of ``seq_stride``; trailing
               tokens that do not fill a whole window are dropped by the strided conv.

        Returns:
            (conv_out, compressed):
              conv_out   -- raw encoder-conv output in channel-first layout
                            [N, latent_dim, L // seq_stride], taken *before* the
                            layer norm. This is not a reconstruction; the decoder
                            is not run here (use ``decode`` for that).
              compressed -- layer-normalized latent tokens
                            [N, L // seq_stride, latent_dim].
        """
        x = self.encoder_linear(x)                    # [N, L, latent_dim]
        x = x.permute(0, 2, 1)                        # [N, latent_dim, L]
        x = self.encoder_conv(x)                      # [N, latent_dim, L // seq_stride]
        compressed = self.norm(x.permute(0, 2, 1))    # [N, L // seq_stride, latent_dim]
        return x, compressed

    def decode(self, compressed):
        """Decode latent tokens back to the input token space.

        Args:
            compressed: [N, L', latent_dim] (e.g. the second value returned by ``forward``).

        Returns:
            [N, L' * seq_stride, in_dim].
        """
        x = compressed.permute(0, 2, 1)   # [N, latent_dim, L']
        x = self.decoder_deconv(x)        # [N, latent_dim, L' * seq_stride]
        x = x.permute(0, 2, 1)            # [N, L' * seq_stride, latent_dim]
        return self.decoder_linear(x)     # [N, L' * seq_stride, in_dim]

class GradientBridge(torch.nn.Module):
    """Wrap a shallow clone of a module that always runs with autograd enabled.

    The clone shares every parameter, buffer and submodule *object* with the
    wrapped ``module``, so weight updates are visible through both. It owns its
    own ``_parameters`` / ``_buffers`` / ``_modules`` dictionaries, so adding or
    replacing entries on one does not affect the other.
    """

    def __init__(self, module, kwargs=None):
        """
        Args:
            module: the ``nn.Module`` instance to clone shallowly.
            kwargs: ignored; accepted for backward compatibility with callers
                that pass the module's constructor arguments. No second
                instance is constructed any more, so no duplicate weights are
                allocated or randomly initialized.
        """
        super().__init__()
        cls = type(module)
        # Allocate without running __init__: every piece of state is copied below.
        clone = cls.__new__(cls)
        clone.__dict__ = module.__dict__.copy()
        clone._parameters = module._parameters.copy()
        clone._buffers = module._buffers.copy()
        clone._modules = module._modules.copy()
        self.module = clone

    def forward(self, *args, **kwargs):
        # enable_grad() re-enables autograd even inside a caller's torch.no_grad(),
        # so the wrapped module always records a graph.
        with torch.enable_grad():
            return self.module(*args, **kwargs)

class VectorQueues(nn.Module):
    """A fixed bank of per-label ring buffers of vectors.

    Storage is the buffer ``queues`` of shape [num_queues, queue_size, vector_dim].
    For each queue, ``queue_pointers`` holds the next write slot and
    ``queue_counts`` the number of slots written so far (capped at queue_size).
    Once a queue is full, new vectors overwrite the oldest ones.
    """

    def __init__(self, num_queues=1000, queue_size=1000, vector_dim=384):
        """
        Args:
            num_queues: number of queues; valid labels are 0 .. num_queues-1.
            queue_size: capacity of each queue.
            vector_dim: dimensionality of the stored vectors.
        """
        super(VectorQueues, self).__init__()
        self.num_queues = num_queues
        self.queue_size = queue_size
        self.vector_dim = vector_dim

        # All queue contents in one tensor: [num_queues, queue_size, vector_dim].
        self.register_buffer('queues', torch.zeros((num_queues, queue_size, vector_dim)))
        # Next write position of each queue.
        self.register_buffer('queue_pointers', torch.zeros((num_queues), dtype=torch.long))
        # Number of filled slots of each queue (saturates at queue_size).
        self.register_buffer('queue_counts', torch.zeros((num_queues), dtype=torch.long))

        # NOTE(review): only the pointer/count buffers are moved to `self.device`
        # here; `queues` stays on the CPU until `to()` is called.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.queue_pointers = self.queue_pointers.to(self.device)
        self.queue_counts = self.queue_counts.to(self.device)

        # CPU generator used by `select` to draw slot indices.
        self.rng = torch.Generator()

    def add_vectors(self, labels, vectors):
        """Append each vector to the ring buffer selected by its label.

        Vectors are written in batch order, so several vectors sharing a label
        occupy consecutive slots of that queue.

        Args:
            labels: 1-D integer tensor [batch_size] of queue indices.
            vectors: tensor [batch_size, vector_dim].

        Raises:
            ValueError: on a shape mismatch or a label outside [0, num_queues-1].
                All checks run before any write, so a rejected batch leaves the
                queues untouched.
        """
        labels = labels.to(self.device)
        vectors = vectors.to(self.device)

        if labels.dim() != 1:
            raise ValueError(f"labels 应为1维张量, 实际维度 {labels.dim()}")
        if vectors.size(0) != labels.size(0):
            raise ValueError(f"向量数量 ({vectors.size(0)}) 与标签数量 ({labels.size(0)}) 不匹配")
        if vectors.size(1) != self.vector_dim:
            raise ValueError(f"向量维度 ({vectors.size(1)}) 与队列维度 ({self.vector_dim}) 不匹配")

        # Validate every label before writing anything so that a bad label
        # cannot leave the batch half-inserted.
        out_of_range = (labels < 0) | (labels >= self.num_queues)
        if out_of_range.any():
            label = labels[out_of_range][0]
            raise ValueError(f"标签 {label} 超出范围 [0, {self.num_queues-1}]")

        for label, vector in zip(labels, vectors):
            ptr = self.queue_pointers[label]
            self.queues[label, ptr] = vector
            # Advance the write pointer, wrapping around (ring buffer).
            self.queue_pointers[label] = (ptr + 1) % self.queue_size
            # Track the number of filled slots, saturating at capacity.
            if self.queue_counts[label] < self.queue_size:
                self.queue_counts[label] += 1

    def to(self, device):
        """Move all queue buffers to `device` and use it for future inputs.

        NOTE(review): this override accepts only a device (no dtype /
        non_blocking forms). A parent module's `.to()` presumably reaches this
        module through `nn.Module._apply`, which does not call this override,
        so `self.device` may go stale in that case. Confirm.
        """
        self.device = device
        self.queues = self.queues.to(device)
        self.queue_pointers = self.queue_pointers.to(device)
        self.queue_counts = self.queue_counts.to(device)
        return self

    def select(self, label):
        """Draw one random slot from the queue of each label.

        Args:
            label: 1-D integer tensor [B] of queue indices.

        Returns:
            Tensor [B, vector_dim]; row i is a uniformly random slot of queue
            label[i]. Slots are drawn from the whole capacity (0..queue_size-1),
            so for a queue that is not yet full an unwritten (all-zero) slot may
            be returned.
        """
        slot = torch.randint(
            low=0,
            high=self.queue_size,
            size=(label.shape[0],),
            generator=self.rng,
        )
        return self.queues[label, slot]

class DictionaryConditioner(nn.Module):
    """A dictionary of `dict_size` prototype vectors maintained by EMA k-means.

    The dictionary is not trained by gradient descent (``requires_grad=False``);
    `update` moves each entry toward the mean of the keys assigned to it, and
    `query` performs a softmax-weighted soft lookup.
    """

    def __init__(self, dim, dict_size, momentum=0.99):
        """
        Args:
            dim: feature dimension of the dictionary entries.
            dict_size: number of dictionary entries.
            momentum: EMA momentum used by `update` (weight kept on the old entry).
        """
        super().__init__()
        self.dim = dim
        self.dict_size = dict_size
        self.momentum = momentum

        # Frozen Parameter (appears in parameters()/state_dict), rewritten by
        # `update`; it starts at all zeros.
        self.dictionary = nn.Parameter(torch.zeros(dict_size, dim), requires_grad=False)

        # Number of keys assigned to each entry, accumulated over all updates.
        self.register_buffer("usage_count", torch.zeros(dict_size))
        # Plain Python counters (not stored in state_dict).
        self.step = 0
        self.warmup_steps = 2000

    def query(self, query_vec, dict_grad):
        """Soft-retrieve dictionary entries for each query token.

        Args:
            query_vec: [b, t, dim] queries; L2-normalized along the last axis here.
            dict_grad: [k, dim] dictionary matrix to attend over.

        Returns:
            [b, t, dim]: softmax(query . dict_grad^T / 0.1) @ dict_grad.
            `dict_grad` is not normalized here, so the scores are cosine
            similarities only if its rows are unit-norm.
        """
        query_vec = F.normalize(query_vec, dim=-1)
        sim_matrix = torch.einsum('btd, kd -> btk', query_vec, dict_grad)
        weights = F.softmax(sim_matrix / 0.1, dim=-1)  # temperature 0.1
        return torch.einsum('btk, kd -> btd', weights, dict_grad)

    def update(self, key_vec):
        """One EMA k-means step: assign keys to their nearest entry, then move entries toward cluster means.

        Args:
            key_vec: [..., dim] features (flattened to [n, dim]); L2-normalized here.

        Returns:
            The new dictionary [dict_size, dim]. It is also written into
            `self.dictionary`; the returned tensor keeps the autograd graph
            back to `key_vec`.

        Raises:
            RuntimeError: if the normalized keys or the cluster means contain NaN.
        """
        if self.step < self.warmup_steps:
            momentum = min(0.5, 1 - (self.step / self.warmup_steps))
            # NOTE(review): taking max() with self.momentum means the warmup value
            # can never fall below self.momentum; with the default 0.99 this
            # branch always yields 0.99. Confirm whether min() was intended.
            momentum = max(momentum, self.momentum)
        else:
            momentum = self.momentum

        key_vec = F.normalize(key_vec, dim=-1)
        flat_keys = key_vec.reshape(-1, self.dim)  # [n, dim]

        if torch.isnan(key_vec).any():
            raise RuntimeError("key_vec 中存在 NaN 值")

        # 1. Nearest entry (largest dot product) for every key.
        sim_matrix = torch.einsum('nd, kd -> nk', flat_keys, self.dictionary)
        closest_idx = torch.argmax(sim_matrix, dim=-1)  # [n]

        # 2. One-hot cluster assignment, per-cluster counts and sums.
        cluster_mask = F.one_hot(closest_idx, num_classes=self.dict_size)  # [n, dict_size]
        cluster_counts = cluster_mask.sum(dim=0)  # [dict_size]
        cluster_sums = torch.einsum('nk, nd -> kd', cluster_mask.float(), flat_keys)

        # 3. Cluster means; empty clusters keep their current entry. The divisor
        # is clamped to 1 so the unselected torch.where branch never holds 0/0:
        # that NaN does not show up in the forward value but would poison the
        # gradient flowing back to key_vec.
        cluster_means = torch.where(
            cluster_counts.unsqueeze(1) > 0,
            cluster_sums / cluster_counts.clamp(min=1).unsqueeze(1),
            self.dictionary.detach(),
        )

        if torch.isnan(cluster_means).any():
            raise RuntimeError(f"{cluster_counts} cluster_means 中存在 NaN 值")

        # 4. EMA update of all entries at once.
        new_dictionary = momentum * self.dictionary.detach() + (1 - momentum) * cluster_means

        # 5. Store the result and update the statistics.
        self.dictionary.data = new_dictionary
        self.usage_count.add_(cluster_counts)
        self.step += 1

        return new_dictionary  # keeps the graph back to key_vec
        
class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer with a learnable codebook (VQ-VAE style losses).

    Note that `forward` returns the codebook weight and the VQ loss, not the
    quantized input.
    """
    def __init__(self, codebook_size, latent_dim, beta=0.25):
        super().__init__()
        self.codebook = nn.Embedding(codebook_size, latent_dim)
        self.beta = beta
        self.codebook_size = codebook_size
        self.latent_dim = latent_dim

    def forward(self, z):
        """Match every token of `z` to its nearest codebook entry and compute the VQ loss.

        Args:
            z: [B, T, D] with D == latent_dim.

        Returns:
            (codebook_weight, loss):
              codebook_weight -- ``self.codebook.weight`` [codebook_size, D];
              loss -- codebook_loss + beta * commitment_loss, where
                codebook_loss   = mean((q - sg[z])^2)  trains the codebook,
                commitment_loss = mean((sg[q] - z)^2)  trains the encoder,
              q is the nearest codebook entry of each token and sg is stop-gradient.
        """
        B, T, D = z.shape
        z_flat = z.reshape(-1, D)  # [B*T, D]

        # Squared L2 distance ||z||^2 + ||e||^2 - 2 z.e for every (token, entry) pair.
        distances = (
            torch.sum(z_flat**2, dim=1, keepdim=True)               # [B*T, 1]
            + torch.sum(self.codebook.weight**2, dim=1)              # [codebook_size]
            - 2 * torch.matmul(z_flat, self.codebook.weight.t())    # [B*T, codebook_size]
        )
        nearest = torch.argmin(distances, dim=1)  # [B*T]
        z_q_flat = self.codebook(nearest)  # [B*T, D], differentiable w.r.t. the codebook

        # Both terms use the raw codebook lookup. Computing the codebook term on
        # a straight-through tensor (z + sg[q - z]) would route its gradient to
        # the encoder, and the codebook would never receive any gradient.
        commit_loss = self.beta * torch.mean((z_q_flat.detach() - z_flat) ** 2)
        codebook_loss = torch.mean((z_q_flat - z_flat.detach()) ** 2)

        return self.codebook.weight, commit_loss + codebook_loss

    def query(self, query_vec):
        """Soft-retrieve codebook entries for each query token.

        Queries are L2-normalized but codebook rows are not, so the scores are
        plain dot products (cosine similarity only if the rows are unit-norm).

        Args:
            query_vec: [b, t, latent_dim].

        Returns:
            [b, t, latent_dim]: softmax(q . W^T / 0.1) @ W with W the codebook weight.
        """
        weight = self.codebook.weight
        query_vec = F.normalize(query_vec, dim=-1)
        sim_matrix = torch.einsum('btd, kd -> btk', query_vec, weight)
        weights = F.softmax(sim_matrix / 0.1, dim=-1)  # temperature 0.1
        return torch.einsum('btk, kd -> btd', weights, weight)

class FinalLayer(nn.Module):
    """
    Output head of SiT: an adaLN-modulated LayerNorm followed by a linear map
    from each token to patch_size * patch_size * out_channels values.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        per_token_out = patch_size * patch_size * out_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, per_token_out, bias=True)
        # Maps the conditioning signal to (shift, scale) for the final norm.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """
        x: tokens to project.
        c: conditioning; split in half along the last axis into shift and scale.
           `modulate` broadcasts a 2-D c over the token axis and uses a
           higher-rank c as-is.
        """
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=-1)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))



class SiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        enc_layer=-1
    ):
        """Build the SiT backbone plus the extra "query" branch used by `forward`.

        Args:
            input_size: spatial size passed to both patch embedders.
            patch_size: patch size of both patch embedders and of the final layer.
            in_channels: number of input channels.
            hidden_size: token width.
            depth: number of SiTBlocks in `self.blocks`.
            num_heads: attention heads per block.
            mlp_ratio: MLP width multiplier inside each block.
            class_dropout_prob: label-dropout probability passed to LabelEmbedder.
            num_classes: number of class labels.
            learn_sigma: if True the final layer emits 2 * in_channels channels and
                the forward methods keep only the first half.
            enc_layer: split point in `self.blocks` between the blocks that process
                the query branch and the blocks that process the main tokens in
                `forward` (see `forward` for how the default -1 is handled).
        """
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        # Separate patch embedder for the query branch (it shares pos_embed).
        self.x_query_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        # Residual FFN applied to the query tokens; followed by `self.norm`.
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size)
        )
        # Projection head applied to the query tokens for the cosine-similarity term.
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size)
        )
        self.norm = nn.LayerNorm(hidden_size)
        self.enc_layer=enc_layer
        self.depth=depth
        # NOTE(review): `self.query_blocks` is not created here, yet `forward_serial`
        # and `forward1` index it; unless it is attached elsewhere, those methods
        # raise AttributeError.

        # Module creation order fixes parameter order and RNG use during default init.
        self.blocks = nn.ModuleList([
            SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Initialize all weights. Later steps overwrite earlier ones, so order matters.

        - xavier_uniform on every nn.Linear (zero bias), via self.apply;
        - fixed 2-D sin-cos values for pos_embed;
        - both patch-embed convs initialized like nn.Linear;
        - N(0, 0.02) for the label table and the timestep MLP;
        - zeroed adaLN modulation in every block and in the final layer, plus a
          zeroed final linear (adaLN-Zero);
        - Kaiming / scaled-Xavier re-init of self.ffn.
        """
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Fill the (frozen) pos_embed with a 2-D sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in SiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Query-branch patch embedder: same Linear-style init as x_embedder.
        w = self.x_query_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_query_embedder.proj.bias, 0)
        for layer in self.ffn:
            if isinstance(layer, nn.Linear):
                # First layer (in_features == hidden_size): Kaiming normal init.
                if layer.in_features == self.norm.weight.size(0):
                    nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
                # Second layer: Xavier uniform with a reduced gain.
                else:
                    nn.init.xavier_uniform_(layer.weight, gain=1/math.sqrt(2))

                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
        # self.mlp is not re-initialized here; it keeps the xavier_uniform init
        # applied by _basic_init above.

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """Fold per-patch token predictions back into an image-shaped tensor.

        Args:
            x: (N, T, patch_size**2 * C) with C == self.out_channels and T a
               perfect square (an h x w grid of tokens with h == w).

        Returns:
            imgs: (N, C, h * patch_size, w * patch_size).
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(x.shape[0], h, w, p, p, c)
        x = torch.einsum('nhwpqc->nchpwq', x)
        return x.reshape(x.shape[0], c, h * p, w * p)

    def forward_ori(self, x, t, y):
        """
        Plain SiT forward pass (no query branch).
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        Returns an (N, in_channels, H, W) prediction.
        """
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(t)                   # (N, D)
        y = self.y_embedder(y, self.training)    # (N, D)
        c = t + y                                # (N, D)
        for block in self.blocks:
            x = block(x, c)                      # (N, T, D)
        x = self.final_layer(x, c)               # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)                   # (N, out_channels, H, W)
        if self.learn_sigma:
            # Keep only the first half of the channels; the second half is the
            # extra learn_sigma output.
            x, _ = x.chunk(2, dim=1)
        return x
    
    # def forward(self, x, x_pure, t, y):
    #     """
    #     Forward pass of SiT.
    #     x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
    #     t: (N,) tensor of diffusion timesteps
    #     y: (N,) tensor of class labels
    #     """
    #     x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
    #     x_pure = self.x_pure_embedder(x_pure) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
    #     t = self.t_embedder(t)                   # (N, D)
    #     label=y.to(dtype=torch.int64)
    #     y = self.y_embedder(y, self.training)    # (N, D)
    #     c = t + y                                # (N, D)

    #     for i in range(0,8):#self.blocks:
    #         x_pure = self.blocks[i](x_pure, y)                      # (N, T, D)

    #     # condition=0.999*self.dt[label]+0.001*x_pure
    #     # self.dt[label]=0.999*self.dt[label]+0.001*x_pure.detach()

    #     all_x_pure=concat_all_gather(x_pure)
    #     all_label=concat_all_gather(label)
    #     # 对 a 按照 label 进行分组平均
    #     with torch.autograd.set_detect_anomaly(True):
    #         unique_x_pure = scatter_mean(all_x_pure, all_label, dim=0)
    #     # 获取去重后的标签
    #     unique_labels = torch.unique(all_label)
    #     unique_x_pure_filled=unique_x_pure[unique_labels]
    #     self.cnt+=1
    #     update_rate = min(0.1, 0.001 * (self.cnt/5000 + 1))  # 随训练增加更新率
    #     with torch.no_grad():
    #         self.dt[unique_labels]=(1-update_rate)*self.dt[unique_labels]+update_rate*unique_x_pure_filled.detach()
    #     condition=self.dt[label]-update_rate*unique_x_pure[label].detach()+update_rate*unique_x_pure[label]############################################

    #     condition_norm=self.condition_norm(condition)
    #     c_norm=self.c_norm(c)
    #     # combined = torch.cat([
    #     #     condition_norm, 
    #     #     c_norm.unsqueeze(1).expand(-1, condition_norm.size(1), -1)
    #     # ], dim=-1)
    #     # combined = self.cond(combined)  # 通过线性变换融合
    #     combined=condition_norm+c_norm.unsqueeze(1)

    #     # for i in range(4,8):#self.blocks:
    #     #     combined = self.blocks[i](combined, c)  
        
    #     for i in range(8,12):#self.blocks:
    #         x = self.blocks[i](x, combined,True) 
        
    #     x = self.final_layer(x, combined)               # (N, T, patch_size ** 2 * out_channels)
    #     x = self.unpatchify(x)                   # (N, out_channels, H, W)
    #     if self.learn_sigma:
    #         x, _ = x.chunk(2, dim=1)
    #     return x,x

    # def forward(self, x, t, y):
    #     """
    #     Forward pass of SiT.
    #     x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
    #     t: (N,) tensor of diffusion timesteps
    #     y: (N,) tensor of class labels
    #     """
    #     x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
    #     # x_pure = self.x_pure_embedder(x_pure) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
    #     t = self.t_embedder(t)                   # (N, D)
    #     label=y.to(dtype=torch.int64)
    #     y = self.y_embedder(y, self.training)    # (N, D)
    #     c = t + y                                # (N, D)

    #     # for i in range(0,8):#self.blocks:
    #     #     x_pure = self.aux_blocks[i](x_pure, y)                      # (N, T, D)
        
    #     # all_x_pure=concat_all_gather(x_pure)
    #     # all_label=concat_all_gather(label)
    #     # # 对 a 按照 label 进行分组平均
    #     # with torch.autograd.set_detect_anomaly(True):
    #     #     unique_x_pure = scatter_mean(all_x_pure, all_label, dim=0)
    #     # # 获取去重后的标签
    #     # unique_labels = torch.unique(all_label)
    #     # unique_x_pure_filled=unique_x_pure[unique_labels]
    #     # with torch.no_grad():
    #     #     self.dt[unique_labels]=0.999*self.dt[unique_labels]+0.001*unique_x_pure_filled.detach()

    #     # condition=self.dt[label]-0.001*unique_x_pure[label].detach()+0.001*unique_x_pure[label]############################################
    #     condition=self.dt[label]

    #     combined = self.condition_norm(condition) + self.c_norm(c).unsqueeze(1)
        
    #     for i in range(0,12):#self.blocks:
    #         x = self.blocks[i](x, combined,True) 
        
    #     x = self.final_layer(x,combined)               # (N, T, patch_size ** 2 * out_channels)
    #     x = self.unpatchify(x)                   # (N, out_channels, H, W)
    #     if self.learn_sigma:
    #         x, _ = x.chunk(2, dim=1)
    #     return x

    def forward_serial(self, x, x_aux, t,taux, y,fl=False):  # training variant
        """
        Two-stream forward pass: the sequential (non-batched) counterpart of `forward`.

        NOTE(review): as written this method cannot run. It indexes
        `self.query_blocks`, which `__init__` does not create. It also hard-codes
        4 query blocks and 8 main blocks instead of using `enc_layer` / `depth`.

        x: (N, C, H, W) main input.
        x_aux: second input, embedded by the same layers (assumed same shape as x).
        t, taux: (N,) diffusion timesteps for x and x_aux.
        y: (N,) class labels shared by both streams.
        fl: unused.
        Returns (x_out, x_aux_out, loss), where loss is the negative mean per-token
        cosine similarity between self.mlp(x_query) and self.mlp(x_aux_query).
        """
        x_query = self.x_query_embedder(x) + self.pos_embed 
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        x_aux_query = self.x_query_embedder(x_aux) + self.pos_embed 
        x_aux = self.x_embedder(x_aux) + self.pos_embed 
        t = self.t_embedder(t)                   # (N, D)
        taux = self.t_embedder(taux)  
        y = self.y_embedder(y, self.training)    # (N, D)
        c = t + y                                # (N, D)
        c_aux = taux + y                                # (N, D)
        
        # Query branch: 4 query blocks per stream (self.query_blocks is undefined, see NOTE).
        for i in range(0,4):
            x_query = self.query_blocks[i](x_query, c) 
            x_aux_query = self.query_blocks[i](x_aux_query, c_aux)                      # (N, T, D)  

        ffn_out = self.ffn(x_query)
        x_query = self.norm(x_query + ffn_out)  # residual + LayerNorm
        ffn_out = self.ffn(x_aux_query)
        x_aux_query = self.norm(x_aux_query + ffn_out)  # residual + LayerNorm

        # Main branch: blocks 0..7, conditioned on query tokens + (t_emb + y_emb).
        for i in range(0,8):
            x = self.blocks[i](x, x_query+c.unsqueeze(1),True)
            x_aux = self.blocks[i](x_aux, x_aux_query+c_aux.unsqueeze(1),True)
        
        x = self.final_layer(x,x_query+c.unsqueeze(1))               # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)                   # (N, out_channels, H, W)
        if self.learn_sigma:
            x, _ = x.chunk(2, dim=1)
        
        x_aux = self.final_layer(x_aux,x_aux_query+c_aux.unsqueeze(1))               # (N, T, patch_size ** 2 * out_channels)
        x_aux = self.unpatchify(x_aux)                   # (N, out_channels, H, W)
        if self.learn_sigma:
            x_aux, _ = x_aux.chunk(2, dim=1)
        # Third output: negative mean cosine similarity of the projected query streams.
        return x,x_aux,-F.cosine_similarity(self.mlp(x_query),self.mlp(x_aux_query),dim=2).mean()
    
    def forward(self, x, x_aux, t, taux, y):
        """Two-stream training forward pass; both streams are run as one batch.

        Each input is embedded twice: by `x_embedder` (main tokens) and by
        `x_query_embedder` (query tokens). The query tokens of both streams go
        through blocks [0, enc_layer) conditioned on t_emb + y_emb, then a
        residual FFN + LayerNorm. The main tokens go through blocks
        [enc_layer, depth) conditioned on `query + (t_emb + y_emb)`.

        Args:
            x: (N, C, H, W) main input.
            x_aux: auxiliary input embedded by the same layers (assumed same shape as x).
            t: (N,) timesteps for x.
            taux: (N,) timesteps for x_aux.
            y: (N,) class labels shared by both streams.

        Returns:
            (out, out_aux, loss): predictions for x and x_aux, each
            (N, in_channels, H, W), and the negative mean per-token cosine
            similarity between the `self.mlp` projections of the two query streams.
        """
        n_main = x.size(0)
        n_aux = x_aux.size(0)

        x_query = self.x_query_embedder(x) + self.pos_embed
        x = self.x_embedder(x) + self.pos_embed
        x_aux_query = self.x_query_embedder(x_aux) + self.pos_embed
        x_aux = self.x_embedder(x_aux) + self.pos_embed
        t = self.t_embedder(t)
        taux = self.t_embedder(taux)
        y = self.y_embedder(y, self.training)
        c = t + y
        c_aux = taux + y

        # With a negative enc_layer (the default is -1), range(enc_layer, depth)
        # would start at blocks[-1] and run the last block twice; treat it as 0.
        split = max(self.enc_layer, 0)

        # Query branch, both streams batched together.
        combined_query = torch.cat([x_query, x_aux_query], dim=0)
        combined_c = torch.cat([c, c_aux], dim=0)
        for i in range(split):
            combined_query = self.blocks[i](combined_query, combined_c)
        combined_query = self.norm(combined_query + self.ffn(combined_query))  # residual FFN + LayerNorm
        x_query, x_aux_query = torch.split(combined_query, [n_main, n_aux], dim=0)

        # Conditioning for the main branch and the final layer.
        cond1 = x_query + c.unsqueeze(1)
        cond2 = x_aux_query + c_aux.unsqueeze(1)
        combined_cond = torch.cat([cond1, cond2], dim=0)

        # Main branch, both streams batched together.
        combined_x = torch.cat([x, x_aux], dim=0)
        for i in range(split, self.depth):
            combined_x = self.blocks[i](combined_x, combined_cond, True)
        x, x_aux = torch.split(combined_x, [n_main, n_aux], dim=0)

        x = self.unpatchify(self.final_layer(x, cond1))
        if self.learn_sigma:
            x, _ = x.chunk(2, dim=1)

        x_aux = self.unpatchify(self.final_layer(x_aux, cond2))
        if self.learn_sigma:
            x_aux, _ = x_aux.chunk(2, dim=1)

        sim = F.cosine_similarity(self.mlp(x_query), self.mlp(x_aux_query), dim=2)
        return x, x_aux, -sim.mean()

    def forward1(self, x, t, y):  # eval path
        """
        Single-input forward pass (the "eval" path; no auxiliary branch).

        x: (N, C, H, W) spatial inputs (images or latent representations)
        t: (N,) diffusion timesteps
        y: (N,) class labels

        Returns the unpatchified model output. When ``learn_sigma`` is set,
        only the first half of the output channels is returned.

        NOTE(review): the block indexing differs from ``forward``.
        - This method runs the query stream through ``self.query_blocks[0:4]``
          and the main stream through ``self.blocks[0:8]`` (hard-coded counts).
        - ``forward`` instead uses ``self.blocks[0:enc_layer]`` for the query
          stream and ``self.blocks[enc_layer:depth]`` for the main stream.
        Confirm this path uses the same weights the model was trained with.
        """
        query = self.x_query_embedder(x) + self.pos_embed
        tokens = self.x_embedder(x) + self.pos_embed  # (N, T, D), T = H * W / patch_size ** 2
        cond_vec = self.t_embedder(t) + self.y_embedder(y, self.training)

        for idx in range(4):
            query = self.query_blocks[idx](query, cond_vec)

        # Residual FFN followed by normalization.
        query = self.norm(query + self.ffn(query))

        # Per-token conditioning shared by every main block and the final layer.
        token_cond = query + cond_vec.unsqueeze(1)
        for idx in range(8):
            tokens = self.blocks[idx](tokens, token_cond, True)

        out = self.unpatchify(self.final_layer(tokens, token_cond))
        if self.learn_sigma:
            out, _ = out.chunk(2, dim=1)
        return out


    def forward_with_cfg(self, x, t, y, cfg_scale):
        """
        Forward pass of SiT that also batches the unconditional forward pass for
        classifier-free guidance.

        The first half of ``x`` is duplicated so the conditional and
        unconditional predictions share the same input. ``y`` presumably holds
        the conditional labels in its first half and the null labels in its
        second half, as in the GLIDE/DiT sampling recipe; confirm against the
        sampler.

        ``forward`` takes a main and an auxiliary input and returns
        ``(out, out_aux, loss)``. Here the same batch and timesteps are fed to
        both branches, and only the main output is kept.

        NOTE(review):
        - This assumes the blocks do not mix samples across the batch, so the
          main output does not depend on the auxiliary copy.
        - It doubles the per-step cost compared with a dedicated single-branch
          path.

        Returns a tensor shaped like ``x``. Guidance is applied to the first 3
        channels; the remaining channels are the raw model output.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        # forward requires (x, x_aux, t, taux, y) and returns a 3-tuple.
        model_out, _, _ = self.forward(combined, combined, t, t, y)
        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

# class SiT(nn.Module):
#     """
#     Diffusion model with a Transformer backbone.
#     """
#     def __init__(
#         self,
#         input_size=32,
#         patch_size=2,
#         in_channels=4,
#         hidden_size=1152,
#         depth=28,
#         num_heads=16,
#         mlp_ratio=4.0,
#         class_dropout_prob=0.1,
#         num_classes=1000,
#         learn_sigma=True,
#     ):
#         super().__init__()
#         self.learn_sigma = learn_sigma
#         self.in_channels = in_channels
#         self.out_channels = in_channels * 2 if learn_sigma else in_channels
#         self.patch_size = patch_size
#         self.num_heads = num_heads

#         self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
#         self.t_embedder = TimestepEmbedder(hidden_size)
#         self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
#         num_patches = self.x_embedder.num_patches
#         # Will use fixed sin-cos embedding:
#         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

#         self.x_pure_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)##############
#         self.x_query_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
#         self.aux_blocks = nn.ModuleList([
#             SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(4)
#         ])
#         # self.dt=VectorQuantizer(codebook_size=5000,latent_dim=384)
#         # self.condition_norm = nn.LayerNorm(hidden_size)
#         # self.c_norm = nn.LayerNorm(hidden_size)
#         # self.cond = nn.Linear(2*hidden_size, hidden_size, bias=True)
#         # self.ff= AutoEncoder()
#         # self.query_dict=SparseCrossAttentionTransformer()
#         self.query_teacher=CrossAttentionTransformer()
#         self.brain=Brain()
#         self.dt = nn.Parameter(torch.zeros(5000,384), requires_grad=True)#################################VectorQueues()#
#         # self.mlp = nn.Sequential(
#         #     nn.Linear(in_features=768, out_features=768*4),
#         #     nn.ReLU(),
#         #     nn.Linear(in_features=768*4, out_features=384),
#         # )
#         # self.dt = nn.Embedding(16384, 768)
#         # self.dt.weight.requires_grad = True

#         self.blocks = nn.ModuleList([
#             OminiSiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) if i<0 else SiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for i in range(depth)
#         ])
#         self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
#         self.cnt=0
#         self.z=torch.zeros(2,384)
#         self.initialize_weights()

#     def initialize_weights(self):
#         # Initialize transformer layers:
#         def _basic_init(module):
#             if isinstance(module, nn.Linear):
#                 torch.nn.init.xavier_uniform_(module.weight)
#                 if module.bias is not None:
#                     nn.init.constant_(module.bias, 0)
#         self.apply(_basic_init)

#         # Initialize (and freeze) pos_embed by sin-cos embedding:
#         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
#         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

#         # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
#         w = self.x_embedder.proj.weight.data
#         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
#         nn.init.constant_(self.x_embedder.proj.bias, 0)

#         # Initialize label embedding table:
#         nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

#         # Initialize timestep embedding MLP:
#         nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
#         nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

#         # Zero-out adaLN modulation layers in SiT blocks:
#         for block in self.blocks:
#             nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
#             nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

#         for block in self.aux_blocks:######################33
#             nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
#             nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
#         w = self.x_pure_embedder.proj.weight.data
#         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
#         nn.init.constant_(self.x_pure_embedder.proj.bias, 0)
#         w = self.x_query_embedder.proj.weight.data
#         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
#         nn.init.constant_(self.x_query_embedder.proj.bias, 0)

#         # nn.init.constant_(self.mlp[0].weight, 0)
#         # nn.init.constant_(self.mlp[0].bias, 0)
#         # nn.init.constant_(self.mlp[2].weight, 0)
#         # nn.init.constant_(self.mlp[2].bias, 0)

#         # w = self.s_embedder.proj.weight.data
#         # nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
#         # nn.init.constant_(self.s_embedder.proj.bias, 0)
#         # nn.init.constant_(self.cond.weight, 0)
#         # nn.init.constant_(self.cond.bias, 0)########################3

#         # Zero-out output layers:
#         nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
#         nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
#         nn.init.constant_(self.final_layer.linear.weight, 0)
#         nn.init.constant_(self.final_layer.linear.bias, 0)

#     def unpatchify(self, x):
#         """
#         x: (N, T, patch_size**2 * C)
#         imgs: (N, H, W, C)
#         """
#         c = self.out_channels
#         p = self.x_embedder.patch_size[0]
#         h = w = int(x.shape[1] ** 0.5)
#         assert h * w == x.shape[1]

#         x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
#         x = torch.einsum('nhwpqc->nchpwq', x)
#         imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
#         return imgs


#     # def forward(self, x,x_pure, t, y):
#     #     """
#     #     Forward pass of SiT.
#     #     x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
#     #     t: (N,) tensor of diffusion timesteps
#     #     y: (N,) tensor of class labels
#     #     """
#     #     x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
#     #     t = self.t_embedder(t)                   # (N, D)
#     #     y = self.y_embedder(y, self.training)    # (N, D)
#     #     c = t + y
#     #     cnt=1                              # (N, D)
#     #     for block in self.blocks:
#     #         x = block(x, c)                      # (N, T, D)
#     #         if cnt==8:
#     #             z_contrast=x
#     #         cnt+=1
#     #     x = self.final_layer(x, c)                # (N, T, patch_size ** 2 * out_channels)
#     #     x = self.unpatchify(x)                   # (N, out_channels, H, W)
#     #     if self.learn_sigma:
#     #         x, _ = x.chunk(2, dim=1)
#     #     return x,z_contrast
    
#     def forward_dict_train(self, x, x_pure, t, y):############train
#         """
#         Forward pass of SiT.
#         x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
#         t: (N,) tensor of diffusion timesteps
#         y: (N,) tensor of class labels
#         """
#         x_query = self.x_query_embedder(x) + self.pos_embed
#         x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
#         x_pure = self.x_pure_embedder(x_pure) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
#         t = self.t_embedder(t)                   # (N, D)
#         y = self.y_embedder(y, self.training)    # (N, D)
#         c = t + y                                # (N, D)
        
#         for i in range(0,4):#self.blocks:
#             x_pure = self.aux_blocks[i](x_pure, c) 
#             x_query = self.blocks[i](x_query, c)                     # (N, T, D)

#         x_feature=self.query(x_query,self.dt)

#         for i in range(4,12):#self.blocks:
#             x = self.blocks[i](x, x_feature+c.unsqueeze(1),True)
        
#         x = self.final_layer(x,x_feature+c.unsqueeze(1))               # (N, T, patch_size ** 2 * out_channels)
#         x = self.unpatchify(x)                   # (N, out_channels, H, W)
#         if self.learn_sigma:
#             x, _ = x.chunk(2, dim=1)

#         return x,x,-F.cosine_similarity(x_feature,x_pure,dim=2).mean()
        
#     def forward_dict_eval(self, x, t, y):############eval
#         """
#         Forward pass of SiT.
#         x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
#         t: (N,) tensor of diffusion timesteps
#         y: (N,) tensor of class labels
#         """
#         x_query = self.x_query_embedder(x) + self.pos_embed
#         x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
#         t = self.t_embedder(t)                   # (N, D)
#         y = self.y_embedder(y, self.training)    # (N, D)
#         c = t + y                                # (N, D)
        
#         for i in range(0,4):#self.blocks:
#             x_query = self.blocks[i](x_query, c)                     # (N, T, D)

#         x_feature=self.query(x_query,self.dt)

#         for i in range(4,12):#self.blocks:
#             x = self.blocks[i](x, x_feature+c.unsqueeze(1),True)
        
#         x = self.final_layer(x, x_feature+c.unsqueeze(1))               # (N, T, patch_size ** 2 * out_channels)
#         x = self.unpatchify(x)                   # (N, out_channels, H, W)
#         if self.learn_sigma:
#             x, _ = x.chunk(2, dim=1)

#         return x

#     def forward_vq_train(self, x, x_pure, t, y):############train
#         """
#         Forward pass of SiT.
#         x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
#         t: (N,) tensor of diffusion timesteps
#         y: (N,) tensor of class labels
#         """
#         x_query = self.x_query_embedder(x) + self.pos_embed
#         x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
#         x_pure = self.x_pure_embedder(x_pure) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
#         t = self.t_embedder(t)                   # (N, D)
#         y = self.y_embedder(y, self.training)    # (N, D)
#         c = t + y                                # (N, D)
        
#         for i in range(0,4):#self.blocks:
#             x_pure = self.aux_blocks[i](x_pure, c) 
#             x_query = self.blocks[i](x_query, c)                     # (N, T, D)
        
#         #更新字典
#         dt_grad,vq_loss=self.dt(x_pure)

#         #query
#         x_feature=self.query(x_query,dt_grad.unsqueeze(0))
#         # x_feature=self.dt.query(x_query)

#         for i in range(4,12):#self.blocks:
#             x = self.blocks[i](x, x_feature+c.unsqueeze(1),True)
        
#         x = self.final_layer(x,x_feature+c.unsqueeze(1))               # (N, T, patch_size ** 2 * out_channels)
#         x = self.unpatchify(x)                   # (N, out_channels, H, W)
#         if self.learn_sigma:
#             x, _ = x.chunk(2, dim=1)

#         return x,x,vq_loss

#     def forward_vq_eval(self, x, t, y):############train
#         """
#         Forward pass of SiT.
#         x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
#         t: (N,) tensor of diffusion timesteps
#         y: (N,) tensor of class labels
#         """
#         x_query = self.x_query_embedder(x) + self.pos_embed
#         x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
#         # x_pure = self.x_pure_embedder(x_pure) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
#         t = self.t_embedder(t)                   # (N, D)
#         y = self.y_embedder(y, self.training)    # (N, D)
#         c = t + y                                # (N, D)
        
#         for i in range(0,4):#self.blocks:
#             # x_pure = self.aux_blocks[i](x_pure, c) 
#             x_query = self.blocks[i](x_query, c)                     # (N, T, D)
        
#         #更新字典
#         # dt_grad,vq_loss=self.dt(x_pure)

#         #query
#         x_feature=self.query(x_query,self.dt.codebook.weight.unsqueeze(0))
#         # x_feature=self.dt.query(x_query)

#         for i in range(4,12):#self.blocks:
#             x = self.blocks[i](x, x_feature+c.unsqueeze(1),True)
        
#         x = self.final_layer(x,x_feature+c.unsqueeze(1))               # (N, T, patch_size ** 2 * out_channels)
#         x = self.unpatchify(x)                   # (N, out_channels, H, W)
#         if self.learn_sigma:
#             x, _ = x.chunk(2, dim=1)

#         return x

#     def forward_codebook(self, x, x_pure, t, y):############train
#         """
#         Forward pass of SiT.
#         x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
#         t: (N,) tensor of diffusion timesteps
#         y: (N,) tensor of class labels
#         """
#         x_query = self.x_query_embedder(x) + self.pos_embed
#         x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
#         # x_pure = self.x_pure_embedder(x_pure) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
#         t = self.t_embedder(t)                   # (N, D)
#         y = self.y_embedder(y, self.training)    # (N, D)
#         c = t + y                                # (N, D)
        
#         for i in range(0,4):#self.blocks:
#             # x_pure = self.aux_blocks[i](x_pure, c) 
#             x_query = self.blocks[i](x_query, c)                     # (N, T, D)
        
#         #更新字典
#         dt=self.mlp(self.dt.weight.data)

#         # query_vec = F.normalize(x_query, dim=-1)  # [b, token_len, dim]
#         # # 计算余弦相似度
#         # sim_matrix = torch.einsum('btd, kd -> btk', query_vec, dt)
#         # # 软检索 (温度系数0.1)
#         # weights = F.softmax(sim_matrix / 0.1, dim=-1)
#         # # 加权组合字典项
#         # retrieved = torch.einsum('btk, kd -> btd', weights, dt)
#         # #query
#         # x_feature=retrieved
#         x_feature=self.query(x_query,dt.unsqueeze(0))

#         for i in range(4,12):#self.blocks:
#             x = self.blocks[i](x, x_feature+c.unsqueeze(1),True)
        
#         x = self.final_layer(x,x_feature+c.unsqueeze(1))               # (N, T, patch_size ** 2 * out_channels)
#         x = self.unpatchify(x)                   # (N, out_channels, H, W)
#         if self.learn_sigma:
#             x, _ = x.chunk(2, dim=1)

#         return x,x,x

#     def forward1(self, x, x_pure, t, y):############train
#         """
#         Forward pass of SiT.
#         x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
#         t: (N,) tensor of diffusion timesteps
#         y: (N,) tensor of class labels
#         """
#         x_query = self.x_query_embedder(x) + self.pos_embed
#         x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
#         x_pure = self.x_pure_embedder(x_pure) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
#         t = self.t_embedder(t)                   # (N, D)
#         y = self.y_embedder(y, self.training)    # (N, D)
#         c = t + y                                # (N, D)
        
#         for i in range(0,4):#self.blocks:
#             x_pure = self.aux_blocks[i](x_pure, c) 
#             x_query = self.blocks[i](x_query, c)                     # (N, T, D)
        
#         #更新字典
#         # dt_grad,vq_loss=self.dt(x_pure)

#         #query
#         # x_feature=self.query_dict(x_query,self.dt.unsqueeze(0))
#         x_answer=self.query_teacher(x_query,x_pure)
#         x_feature=self.brain(x_query,x_answer)
#         # x_feature=self.dt.query(x_query)

#         for i in range(4,12):#self.blocks:
#             x = self.blocks[i](x, x_feature+c.unsqueeze(1),True)
        
#         x = self.final_layer(x,x_feature+c.unsqueeze(1))               # (N, T, patch_size ** 2 * out_channels)
#         x = self.unpatchify(x)                   # (N, out_channels, H, W)
#         if self.learn_sigma:
#             x, _ = x.chunk(2, dim=1)

#         return x,x,x#-F.cosine_similarity(x_feature,x_answer,dim=2).mean()

#     def forward(self, x, t, y):############train
#         """
#         Forward pass of SiT.
#         x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
#         t: (N,) tensor of diffusion timesteps
#         y: (N,) tensor of class labels
#         """
#         x_query = self.x_query_embedder(x) + self.pos_embed
#         x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
#         # x_pure = self.x_pure_embedder(x_pure) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
#         t = self.t_embedder(t)                   # (N, D)
#         y = self.y_embedder(y, self.training)    # (N, D)
#         c = t + y                                # (N, D)
        
#         for i in range(0,4):#self.blocks:
#             # x_pure = self.aux_blocks[i](x_pure, c) 
#             x_query = self.blocks[i](x_query, c)                     # (N, T, D)
        
#         #更新字典
#         # dt_grad,vq_loss=self.dt(x_pure)

#         #query
#         # x_feature=self.query_dict(x_query,self.dt.unsqueeze(0))
#         x_answer=torch.zeros_like(x_query)#self.query_teacher(x_query,x_pure)
#         x_feature=self.brain(x_query,x_answer)
#         # x_feature=self.dt.query(x_query)

#         for i in range(4,12):#self.blocks:
#             x = self.blocks[i](x, x_feature+c.unsqueeze(1),True)
        
#         x = self.final_layer(x,x_feature+c.unsqueeze(1))               # (N, T, patch_size ** 2 * out_channels)
#         x = self.unpatchify(x)                   # (N, out_channels, H, W)
#         if self.learn_sigma:
#             x, _ = x.chunk(2, dim=1)

#         return x#((x_feature-x_answer)**2).mean()#-F.cosine_similarity(x_feature,x_answer,dim=2).mean()


#     def forward_with_cfg(self, x, t, y, cfg_scale):
#         """
#         Forward pass of SiT, but also batches the unconSiTional forward pass for classifier-free guidance.
#         """
#         # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
#         half = x[: len(x) // 2]
#         combined = torch.cat([half, half], dim=0)
#         model_out,_ = self.forward(combined, t, y)
#         # print("model_out",model_out)
#         # For exact reproducibility reasons, we apply classifier-free guidance on only
#         # three channels by default. The standard approach to cfg applies it to all channels.
#         # This can be done by uncommenting the following line and commenting-out the line following that.
#         # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
#         eps, rest = model_out[:, :3], model_out[:, 3:]
#         cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
#         half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
#         eps = torch.cat([half_eps, half_eps], dim=0)
#         return torch.cat([eps, rest], dim=1)



#################################################################################
#                   Sine/Cosine Positional Embedding Functions                  #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    Build a fixed 2-D sine/cosine positional embedding for a square grid.

    embed_dim: channel count of each embedding (must be even)
    grid_size: int, the grid height and width
    cls_token / extra_tokens: when ``cls_token`` is truthy and
        ``extra_tokens > 0``, that many all-zero rows are prepended.

    Returns an array of shape [grid_size*grid_size, embed_dim], or
    [extra_tokens + grid_size*grid_size, embed_dim] when zero rows are prepended.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # meshgrid is fed (w, h): the first plane varies along columns, the second along rows.
    planes = np.stack(np.meshgrid(coords, coords), axis=0)
    planes = planes.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, planes)
    if not (cls_token and extra_tokens > 0):
        return pos_embed

    prefix = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([prefix, pos_embed], axis=0)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a 2-plane coordinate grid with 1-D sin/cos embeddings per plane.

    ``grid[0]`` fills the first ``embed_dim // 2`` channels and ``grid[1]`` the
    remaining half. Each plane is flattened, giving one row per grid point.

    Returns an array of shape (H*W, embed_dim).
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    per_plane = [
        get_1d_sincos_pos_embed_from_grid(half_dim, plane)  # (H*W, D/2)
        for plane in (grid[0], grid[1])
    ]
    return np.concatenate(per_plane, axis=1)  # (H*W, D)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Sine/cosine embedding of a set of scalar positions.

    embed_dim: output channels per position (must be even)
    pos: numpy array of positions; it is flattened to shape (M,)

    Channel k < D/2 holds sin(pos * 10000**(-k / (D/2))), and channel D/2 + k
    holds the matching cosine.

    Returns an array of shape (M, embed_dim).
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    exponents = np.arange(half_dim, dtype=np.float64) / half_dim
    freqs = 1. / 10000**exponents  # (D/2,)

    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)


#################################################################################
#                                   SiT Configs                                  #
#################################################################################

def SiT_XL_2(**kwargs):
    """SiT-XL, 2x2 patches: 28 blocks, width 1152, 16 heads."""
    return SiT(patch_size=2, depth=28, hidden_size=1152, num_heads=16, **kwargs)

def SiT_XL_4(**kwargs):
    """SiT-XL, 4x4 patches: 28 blocks, width 1152, 16 heads."""
    return SiT(patch_size=4, depth=28, hidden_size=1152, num_heads=16, **kwargs)

def SiT_XL_8(**kwargs):
    """SiT-XL, 8x8 patches: 28 blocks, width 1152, 16 heads."""
    return SiT(patch_size=8, depth=28, hidden_size=1152, num_heads=16, **kwargs)

def SiT_L_2(**kwargs):
    """SiT-L, 2x2 patches: 24 blocks, width 1024, 16 heads."""
    return SiT(patch_size=2, depth=24, hidden_size=1024, num_heads=16, **kwargs)

def SiT_L_4(**kwargs):
    """SiT-L, 4x4 patches: 24 blocks, width 1024, 16 heads."""
    return SiT(patch_size=4, depth=24, hidden_size=1024, num_heads=16, **kwargs)

def SiT_L_8(**kwargs):
    """SiT-L, 8x8 patches: 24 blocks, width 1024, 16 heads."""
    return SiT(patch_size=8, depth=24, hidden_size=1024, num_heads=16, **kwargs)

def SiT_B_2(**kwargs):
    """SiT-B, 2x2 patches: 12 blocks, width 768, 12 heads."""
    return SiT(patch_size=2, depth=12, hidden_size=768, num_heads=12, **kwargs)

def SiT_B_4(**kwargs):
    """SiT-B, 4x4 patches: 12 blocks, width 768, 12 heads."""
    return SiT(patch_size=4, depth=12, hidden_size=768, num_heads=12, **kwargs)

def SiT_B_8(**kwargs):
    """SiT-B, 8x8 patches: 12 blocks, width 768, 12 heads."""
    return SiT(patch_size=8, depth=12, hidden_size=768, num_heads=12, **kwargs)

def SiT_S_2(**kwargs):
    """SiT-S, 2x2 patches: 12 blocks, width 384, 6 heads."""
    return SiT(patch_size=2, depth=12, hidden_size=384, num_heads=6, **kwargs)

def SiT_S_4(**kwargs):
    """SiT-S, 4x4 patches: 12 blocks, width 384, 6 heads."""
    return SiT(patch_size=4, depth=12, hidden_size=384, num_heads=6, **kwargs)

def SiT_S_8(**kwargs):
    """SiT-S, 8x8 patches: 12 blocks, width 384, 6 heads."""
    return SiT(patch_size=8, depth=12, hidden_size=384, num_heads=6, **kwargs)


# Registry mapping "SiT-<size>/<patch>" names to their constructor functions.
SiT_models = {
    'SiT-XL/2': SiT_XL_2,
    'SiT-XL/4': SiT_XL_4,
    'SiT-XL/8': SiT_XL_8,
    'SiT-L/2': SiT_L_2,
    'SiT-L/4': SiT_L_4,
    'SiT-L/8': SiT_L_8,
    'SiT-B/2': SiT_B_2,
    'SiT-B/4': SiT_B_4,
    'SiT-B/8': SiT_B_8,
    'SiT-S/2': SiT_S_2,
    'SiT-S/4': SiT_S_4,
    'SiT-S/8': SiT_S_8,
}
